#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <lua.h>

#include "malloc_hook.h"
#include "skynet.h"
#include "atomic.h"

// Define MEMORY_CHECK to enable extra run-time checks on the allocation
// header, such as detecting a double free through the dogtag field.
// #define MEMORY_CHECK

// Dogtag values stored in mem_cookie.dogtag when MEMORY_CHECK is enabled:
// fill_prefix writes MEMORY_ALLOCTAG, clean_prefix verifies it and then
// overwrites it with MEMORY_FREETAG.
#define MEMORY_ALLOCTAG 0x20140605
#define MEMORY_FREETAG 0x0badf00d

// Process-wide totals over every block handed out by the skynet_* allocators:
// bytes currently charged, and number of live blocks.
static ATOM_SIZET _used_memory = 0;
static ATOM_SIZET _memory_block = 0;

// Per-slot memory accounting entry. mem_stats is indexed by
// (handle & (SLOT_SIZE - 1)), so several handles may compete for one slot;
// see get_allocated_field for how ownership is claimed.
struct mem_data {
  ATOM_ULONG handle;    // handle currently owning this slot; 0 = unclaimed
  ATOM_SIZET allocated; // bytes charged to `handle`; stored unsigned but may
                        // logically be negative (wrapped), see
                        // get_allocated_field
};

// Header written at the start of every raw block handed out by the skynet_*
// allocators. The pointer returned to the caller is raw + cookie_size, where
// cookie_size >= PREFIX_SIZE (larger for calloc/aligned allocations), and the
// 4 bytes just below the returned pointer always hold a copy of cookie_size
// so the header can be located again on free/realloc.
struct mem_cookie {
  size_t size;     // user-visible size in bytes, charged to `handle`
  uint32_t handle; // service handle that performed the allocation
#ifdef MEMORY_CHECK
  uint32_t dogtag; // MEMORY_ALLOCTAG while live, MEMORY_FREETAG once freed
#endif
  // Must stay the last field: when cookie_size == PREFIX_SIZE, the trailing
  // 4-byte copy lands in the final 4 bytes of this struct, so it may only
  // overlay this field (or tail padding), never size/handle/dogtag.
  uint32_t cookie_size; // should be the last
};

// Number of accounting slots. Must be a power of two: handles are mapped to
// slots with the mask (SLOT_SIZE - 1).
#define SLOT_SIZE 0x10000
// Minimum number of header bytes reserved in front of every user block.
#define PREFIX_SIZE sizeof(struct mem_cookie)

// Per-service accounting table, zero-initialized (all slots unclaimed).
static struct mem_data mem_stats[SLOT_SIZE];

#ifndef NOUSE_JEMALLOC

#include "jemalloc.h"

// Unhooked allocator primitives used by skynet_lalloc (whose signature
// matches lua_Alloc). They bypass the cookie header and the per-service
// accounting entirely.
#define raw_realloc je_realloc
#define raw_free je_free

// Return the byte counter that `handle`'s allocations should be charged to,
// or 0 (NULL) when the slot is owned by a different handle and the
// allocation therefore goes uncounted per-service.
//
// Slots are chosen by the low bits of the handle. A slot is (re)claimed with
// a CAS when it is unowned (handle == 0) or its counter is zero or below.
static ATOM_SIZET *get_allocated_field(uint32_t handle) {
  int h = (int)(handle & (SLOT_SIZE - 1));
  struct mem_data *data = &mem_stats[h];
  uint32_t old_handle = data->handle;
  // The counter is an unsigned atomic; view it as signed so a balance that
  // went "below zero" can be recognised.
  ssize_t old_alloc = (ssize_t)data->allocated;
  if (old_handle == 0 || old_alloc <= 0) {
    // data->allocated may be less than zero: frees are charged here even for
    // blocks whose allocation was not counted in this slot (e.g. made while
    // the slot belonged to another handle).
    if (!ATOM_CAS_ULONG(&data->handle, old_handle, handle)) {
      // Lost the race for the slot; do not count this operation.
      return 0;
    }
    if (old_alloc < 0) {
      // Reset a negative balance; if a concurrent update already changed the
      // value, the CAS fails and the balance is left alone.
      ATOM_CAS_SIZET(&data->allocated, (size_t)old_alloc, 0);
    }
  }
  // Slot is owned by (or was just taken by) a different handle.
  if (data->handle != handle) {
    return 0;
  }
  return &data->allocated;
}

// Charge a newly allocated block of `nbytes` to the global totals and, when
// the owning slot is available, to `handle`'s per-service counter.
static inline void update_xmalloc_stat_alloc(uint32_t handle, size_t nbytes) {
  ATOM_SIZET *slot_total;

  ATOM_FADD(&_used_memory, nbytes);
  ATOM_FINC(&_memory_block);

  slot_total = get_allocated_field(handle);
  if (slot_total != NULL) {
    ATOM_FADD(slot_total, nbytes);
  }
}

// Undo update_xmalloc_stat_alloc for a block of `nbytes` being released:
// decrement the global totals and, if available, `handle`'s slot counter.
static inline void update_xmalloc_stat_free(uint32_t handle, size_t nbytes) {
  ATOM_SIZET *slot_total;

  ATOM_FSUB(&_used_memory, nbytes);
  ATOM_FDEC(&_memory_block);

  slot_total = get_allocated_field(handle);
  if (slot_total != NULL) {
    ATOM_FSUB(slot_total, nbytes);
  }
}

// Initialise the header at the start of raw block `ptr` and return the user
// pointer (ptr + cookie_size). Records `sz` and the calling service's
// handle, charges the allocation to the statistics, and stores a copy of
// cookie_size in the 4 bytes just below the user pointer.
static inline void *fill_prefix(char *ptr, size_t sz, uint32_t cookie_size) {
  const uint32_t owner = skynet_current_handle();
  struct mem_cookie *header = (struct mem_cookie *)ptr;
  char *user = ptr + cookie_size;

  header->size = sz;
  header->handle = owner;
#ifdef MEMORY_CHECK
  header->dogtag = MEMORY_ALLOCTAG;
#endif

  // memcpy: the trailing slot is not guaranteed to be 4-byte aligned
  // (calloc with an odd element size produces odd cookie sizes).
  memcpy(user - sizeof(cookie_size), &cookie_size, sizeof(cookie_size));

  update_xmalloc_stat_alloc(owner, sz);
  return user;
}

// Read the header length that fill_prefix stored in the 4 bytes immediately
// below user pointer `ptr`. memcpy is used because that slot may be
// unaligned.
static inline uint32_t get_cookie_size(char *ptr) {
  const char *slot = ptr - sizeof(uint32_t);
  uint32_t value;
  memcpy(&value, slot, sizeof value);
  return value;
}

inline static void *clean_prefix(char *ptr) {
  if (ptr == NULL) {
    fprintf(stderr, "[malloc_hook] Warning: clean_prefix(NULL)\n");
    return NULL;
  }

  uint32_t cookie_size = 0;
  if ((uintptr_t)ptr < sizeof(uint32_t)) {
    fprintf(stderr, "[malloc_hook] Warning: Invalid pointer %p (too small)\n",
            ptr);
    return NULL;
  }

  memcpy(&cookie_size, ptr - sizeof(cookie_size), sizeof(cookie_size));

  if (cookie_size == 0 || cookie_size > 512) {
    fprintf(stderr,
            "[malloc_hook] Warning: Invalid cookie_size (%u) for pointer %p\n",
            cookie_size, ptr);
    return NULL;
  }

  struct mem_cookie *p = (struct mem_cookie *)(ptr - cookie_size);

#ifdef MEMORY_CHECK
  uint32_t dogtag = p->dogtag;
  if (dogtag != MEMORY_ALLOCTAG) {
    fprintf(
        stderr,
        "[malloc_hook] Warning: dogtag mismatch (possible double free) at %p\n",
        ptr);
    return NULL;
  }
  p->dogtag = MEMORY_FREETAG;
#endif

  update_xmalloc_stat_free(p->handle, p->size);
  return p;
}

// Allocation failure is unrecoverable for the skynet_* allocators: report the
// requested size, make sure the message reaches stderr, then abort.
static void malloc_oom(size_t size) {
  const char *const msg =
      "xmalloc: Out of memory trying to allocate %zu bytes\n";
  fprintf(stderr, msg, size);
  fflush(stderr);
  abort();
}

// Print jemalloc's statistics report. NULL write callback / callback argument
// select jemalloc's default output; `opts` is passed through unchanged.
void memory_info_dump(const char *opts) {
  je_malloc_stats_print(NULL, NULL, opts);
}

// Read the boolean jemalloc control `name`, optionally writing *newval when
// newval is non-NULL. Returns the value read back (0 if the call fails and
// leaves it untouched); je_mallctl's status is ignored.
bool mallctl_bool(const char *name, bool *newval) {
  bool current = 0;
  size_t current_len = sizeof current;
  size_t newlen = (newval != NULL) ? sizeof(bool) : 0;

  je_mallctl(name, &current, &current_len, newval, newlen);
  return current;
}

// Invoke the jemalloc control `name` without reading or writing a value and
// return je_mallctl's status code (0 means success).
int mallctl_cmd(const char *name) {
  int status = je_mallctl(name, NULL, NULL, NULL, 0);
  return status;
}

// Read the size_t-valued jemalloc control `name`, optionally writing *newval
// when newval is non-NULL. Returns the value read back (0 if the call leaves
// it untouched); je_mallctl's status is ignored.
size_t mallctl_int64(const char *name, size_t *newval) {
  size_t result = 0;
  size_t result_len = sizeof result;

  je_mallctl(name, &result, &result_len, newval,
             newval != NULL ? sizeof(size_t) : 0);
  return result;
}

// Read the int-valued jemalloc control `name`. When newval is non-NULL the
// control is also set to *newval and the outcome is logged via skynet_error.
// Returns the value read back (0 if the call leaves it untouched).
int mallctl_opt(const char *name, int *newval) {
  int current = 0;
  size_t current_len = sizeof current;

  if (newval == NULL) {
    je_mallctl(name, &current, &current_len, NULL, 0);
    return current;
  }

  int rc = je_mallctl(name, &current, &current_len, newval, sizeof(int));
  if (rc != 0) {
    skynet_error(NULL, "set new value(%d) for (%s) failed: error -> %d\n",
                 *newval, name, rc);
  } else {
    skynet_error(NULL, "set new value(%d) for (%s) succeed\n", *newval, name);
  }
  return current;
}

// Hooked allocator entry points: malloc, realloc, free, calloc and the
// aligned-allocation variants.

// malloc replacement: reserves a PREFIX_SIZE header in front of the returned
// block and charges `size` bytes to the calling service. Never returns NULL;
// failure aborts through malloc_oom.
void *skynet_malloc(size_t size) {
  // size + PREFIX_SIZE must not wrap, otherwise je_malloc would hand back a
  // block too small for the header that fill_prefix writes.
  if (size > SIZE_MAX - PREFIX_SIZE)
    malloc_oom(size);
  void *ptr = je_malloc(size + PREFIX_SIZE);
  if (!ptr)
    malloc_oom(size);
  return fill_prefix(ptr, size, PREFIX_SIZE);
}

// realloc replacement. A NULL `ptr` behaves like skynet_malloc. The block's
// existing header size (cookie_size) is reused for the resized block. Never
// returns NULL. It aborts on allocation failure, and also when `ptr` carries
// no valid header, because the caller's data cannot be preserved then.
void *skynet_realloc(void *ptr, size_t size) {
  if (ptr == NULL)
    return skynet_malloc(size);

  // Validate (and uncharge) the header before trusting its cookie_size.
  void *rawptr = clean_prefix(ptr);
  if (rawptr == NULL) {
    // Continuing would call je_realloc(NULL, ...), which returns a fresh,
    // unrelated block: the caller's data would silently vanish and the
    // original block would leak. Treat this as fatal heap corruption.
    fprintf(stderr, "[malloc_hook] Fatal: realloc of invalid pointer %p\n",
            ptr);
    fflush(stderr);
    abort();
  }
  // clean_prefix does not modify the 4-byte cookie slot just below ptr.
  uint32_t cookie_size = get_cookie_size(ptr);
  if (size > SIZE_MAX - cookie_size)
    malloc_oom(size);

  void *newptr = je_realloc(rawptr, size + cookie_size);
  if (!newptr)
    malloc_oom(size);
  return fill_prefix(newptr, size, cookie_size);
}

// free replacement. NULL is ignored. A pointer whose header fails validation
// is reported and deliberately leaked rather than passed to je_free.
void skynet_free(void *ptr) {
  if (ptr == NULL)
    return;

  void *raw = clean_prefix(ptr);
  if (raw == NULL) {
    fprintf(stderr, "[malloc_hook] Warning: skip free for invalid pointer %p\n",
            ptr);
    return;
  }
  je_free(raw);
}

// calloc replacement returning zero-filled memory. The header is carved out
// of whole elements (cookie_n * size bytes). je_calloc's own overflow check
// therefore still covers the request, and the user pointer sits a multiple
// of `size` bytes past the raw block start. Never returns NULL; failure or
// an unrepresentable request aborts through malloc_oom.
void *skynet_calloc(size_t nmemb, size_t size) {
  if (size == 0) {
    // No usable bytes were requested. Model it as zero one-byte elements so
    // the header arithmetic below never divides by zero.
    nmemb = 0;
    size = 1;
  }
  // ceil(PREFIX_SIZE / size), written so a huge `size` cannot overflow.
  size_t cookie_n = PREFIX_SIZE / size + (PREFIX_SIZE % size != 0);
  // At most max(size, 2 * PREFIX_SIZE), so the product cannot wrap; it must
  // however fit the uint32_t cookie that fill_prefix records.
  size_t cookie_bytes = cookie_n * size;
  if (cookie_bytes > UINT32_MAX || nmemb > SIZE_MAX - cookie_n)
    malloc_oom(nmemb * size);

  void *ptr = je_calloc(nmemb + cookie_n, size);
  if (!ptr)
    malloc_oom(nmemb * size);
  // je_calloc succeeded, so (nmemb + cookie_n) * size fit in size_t and
  // nmemb * size cannot have wrapped.
  return fill_prefix(ptr, nmemb * size, (uint32_t)cookie_bytes);
}

// Header size for an aligned allocation: the smallest multiple of
// `alignment` that is >= PREFIX_SIZE. The user pointer (raw + cookie) is
// then as aligned as the raw block itself.
// NOTE(review): alignments above UINT32_MAX are truncated by the return type.
static inline uint32_t alignment_cookie_size(size_t alignment) {
  if (alignment == 0) {
    // Invalid alignment: avoid dividing by zero and let the jemalloc call
    // that follows reject the request.
    return PREFIX_SIZE;
  }
  if (alignment >= PREFIX_SIZE)
    return alignment;
  // Round PREFIX_SIZE up to the next multiple of alignment.
  return (PREFIX_SIZE + alignment - 1) / alignment * alignment;
}

// memalign replacement. The header occupies a multiple of `alignment` bytes
// so the returned pointer keeps the requested alignment. Never returns NULL;
// failure aborts through malloc_oom.
void *skynet_memalign(size_t alignment, size_t size) {
  uint32_t cookie_size = alignment_cookie_size(alignment);
  // Guard size + cookie_size against wrap-around (header would overflow).
  if (size > SIZE_MAX - cookie_size)
    malloc_oom(size);
  void *ptr = je_memalign(alignment, size + cookie_size);
  if (!ptr)
    malloc_oom(size);
  return fill_prefix(ptr, size, cookie_size);
}

// aligned_alloc replacement. The header occupies a multiple of `alignment`
// bytes, so a `size` that is a multiple of `alignment` keeps the total a
// multiple too. Never returns NULL; failure aborts through malloc_oom.
void *skynet_aligned_alloc(size_t alignment, size_t size) {
  uint32_t cookie_size = alignment_cookie_size(alignment);
  // Guard size + cookie_size against wrap-around (header would overflow).
  if (size > SIZE_MAX - cookie_size)
    malloc_oom(size);
  void *ptr = je_aligned_alloc(alignment, size + cookie_size);
  if (!ptr)
    malloc_oom(size);
  return fill_prefix(ptr, size, cookie_size);
}

// posix_memalign replacement. On success stores the USER pointer (past the
// header) in *memptr and returns 0. Any allocator error aborts through
// malloc_oom, so a non-zero value is never actually returned.
int skynet_posix_memalign(void **memptr, size_t alignment, size_t size) {
  uint32_t cookie_size = alignment_cookie_size(alignment);
  // Guard size + cookie_size against wrap-around (header would overflow).
  if (size > SIZE_MAX - cookie_size)
    malloc_oom(size);

  void *raw = NULL;
  int err = je_posix_memalign(&raw, alignment, size + cookie_size);
  if (err)
    malloc_oom(size);
  // Hand out the pointer past the header. Returning `raw` itself would let
  // the caller overwrite the header and make skynet_free read garbage.
  *memptr = fill_prefix(raw, size, cookie_size);
  return err;
}

#else

// Without jemalloc, skynet_lalloc falls back to the C library allocator
// (no cookie header, no per-service accounting).
#define raw_realloc realloc
#define raw_free free

// Stub used when built without jemalloc: there are no statistics to print.
void memory_info_dump(const char *opts) {
  (void)opts;
  skynet_error(NULL, "No jemalloc");
}

// Stub without jemalloc: log the requested control name and report 0.
size_t mallctl_int64(const char *name, size_t *newval) {
  (void)newval;
  skynet_error(NULL, "No jemalloc : mallctl_int64 %s.", name);
  return 0;
}

// Stub without jemalloc: log the requested control name and report 0.
int mallctl_opt(const char *name, int *newval) {
  (void)newval;
  skynet_error(NULL, "No jemalloc : mallctl_opt %s.", name);
  return 0;
}

// Stub without jemalloc: log the requested control name and report false (0).
bool mallctl_bool(const char *name, bool *newval) {
  (void)newval;
  skynet_error(NULL, "No jemalloc : mallctl_bool %s.", name);
  return 0;
}

// Stub without jemalloc: log the requested command name and report 0.
int mallctl_cmd(const char *name) {
  const int status = 0;
  skynet_error(NULL, "No jemalloc : mallctl_cmd %s.", name);
  return status;
}

#endif

// Total bytes currently charged by the skynet_* allocators, all services.
size_t malloc_used_memory(void) {
  size_t used = ATOM_LOAD(&_used_memory);
  return used;
}

// Number of live blocks handed out by the skynet_* allocators.
size_t malloc_memory_block(void) {
  size_t blocks = ATOM_LOAD(&_memory_block);
  return blocks;
}

// Log, via skynet_error, the bytes charged to every claimed slot followed by
// the grand total. Counters may be logically negative (see
// get_allocated_field), so they are reported as signed values.
void dump_c_mem(void) {
  ssize_t total = 0;
  skynet_error(NULL, "dump all service mem:");
  for (int i = 0; i < SLOT_SIZE; i++) {
    struct mem_data *data = &mem_stats[i];
    // Snapshot each atomic once so the test and the printed value agree.
    unsigned int handle = (unsigned int)data->handle;
    ssize_t allocated = (ssize_t)data->allocated;
    if (handle != 0 && allocated != 0) {
      total += allocated;
      // %08x needs an unsigned int and %zd a signed size; passing the raw
      // unsigned long / size_t fields mismatched the format.
      skynet_error(NULL, ":%08x -> %zdkb %db", handle, allocated >> 10,
                   (int)(allocated % 1024));
    }
  }
  skynet_error(NULL, "+total: %zdkb", total >> 10);
}

// strdup through the hooked allocator; the copy must be released with
// skynet_free. Never returns NULL (skynet_malloc aborts on failure).
char *skynet_strdup(const char *str) {
  size_t bytes = strlen(str) + 1; // include the terminating NUL
  char *copy = skynet_malloc(bytes);
  memcpy(copy, str, bytes);
  return copy;
}

// Allocator with the lua_Alloc shape: nsize == 0 releases `ptr` and returns
// NULL, otherwise `ptr` is resized to nsize bytes. Uses the raw allocator
// directly (no cookie header, no accounting); osize is not needed.
void *skynet_lalloc(void *ptr, size_t osize, size_t nsize) {
  (void)osize;
  if (nsize != 0)
    return raw_realloc(ptr, nsize);
  raw_free(ptr);
  return NULL;
}

// Lua C function: push a table mapping each claimed slot's handle to the
// bytes charged to it (slots with a zero counter are omitted). Returns 1.
int dump_mem_lua(lua_State *L) {
  lua_newtable(L);
  for (int slot = 0; slot < SLOT_SIZE; ++slot) {
    struct mem_data *entry = &mem_stats[slot];
    if (entry->handle == 0 || entry->allocated == 0)
      continue;
    lua_pushinteger(L, entry->allocated);
    lua_rawseti(L, -2, (lua_Integer)entry->handle);
  }
  return 1;
}

// Bytes currently charged to the calling service, or 0 if its slot is held
// by another handle (or nothing is charged).
size_t malloc_current_memory(void) {
  uint32_t handle = skynet_current_handle();
  // get_allocated_field only ever stores a handle in slot
  // (handle & (SLOT_SIZE - 1)), so that is the only slot worth checking;
  // the previous full-table scan visited all SLOT_SIZE entries.
  struct mem_data *data = &mem_stats[(int)(handle & (SLOT_SIZE - 1))];
  if (data->handle == handle && data->allocated != 0) {
    return (size_t)data->allocated;
  }
  return 0;
}

// Debug helper: print the calling service's handle, a caller-supplied label
// and its currently charged byte count (shown in hex through %p).
void skynet_debug_memory(const char *info) {
  const uint32_t self = skynet_current_handle();
  const size_t bytes = malloc_current_memory();
  fprintf(stderr, "[:%08x] %s %p\n", self, info, (void *)bytes);
}
